import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


def initilize(layer):
    """Re-initialise a single layer in place; meant for ``Module.apply``.

    * ``nn.Conv2d``: Kaiming-normal weights (fan-in, ReLU gain). The bias,
      if any, is left untouched.
    * ``nn.Linear``: Kaiming-normal weights (fan-in, ReLU gain) and a bias
      drawn uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
    * Any other module type is ignored.

    Args:
        layer: the module to initialise.
    """
    if isinstance(layer, nn.Conv2d):
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
    elif isinstance(layer, nn.Linear):
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
        if layer.bias is not None:
            # For a Linear layer the fan-in is simply the number of inputs.
            fan_in = layer.in_features
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            # (-bound, bound) is a *range*, so the draw must be uniform;
            # normal_ would read the pair as (mean, std) and centre the
            # biases on -bound.
            nn.init.uniform_(layer.bias, -bound, bound)

class Muer_Conv_InputXY(nn.Module):
    """Convolutional network over an image stacked with a label plane.

    The class index is embedded, projected to a single 32x32 plane and
    concatenated to the input along the channel axis. Because the first
    convolution takes 6 channels and the label plane supplies one of them,
    ``x`` must have shape ``(batch, 5, 32, 32)``.

    Args:
        cfg: configuration object; ``cfg.muer.num_class_embedding`` gives the
            width of the label embedding.
        **kwargs: must contain ``num_classes``.
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__()
        self.num_classes = kwargs['num_classes']
        self.num_class_embedding = cfg.muer.num_class_embedding

        # Class index -> embedding vector -> one flattened 32x32 plane.
        self.label_embedding = nn.Sequential(
            nn.Embedding(self.num_classes, self.num_class_embedding),
            nn.Linear(self.num_class_embedding, 1*32*32))

        # Each 4x4 / stride-2 / pad-1 convolution halves the spatial size;
        # the trailing comments give the side length after that layer.
        # NOTE(review): eps=0.8 is far above BatchNorm2d's 1e-5 default;
        # possibly meant as a momentum value -- confirm before changing.
        self.convs = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False), # 16
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64*2, 4, 2, 1, bias=False), # 8
            nn.BatchNorm2d(64*2, momentum=0.1,  eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*2, 64*4, 4, 2,1, bias=False), # 4
            nn.BatchNorm2d(64*4, momentum=0.1,  eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*4, 64*8, 4, 2, 1, bias=False), # 2
            nn.BatchNorm2d(64*8, momentum=0.1,  eps=0.8),
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Flatten(),
            nn.Dropout(0.4),
            # The last feature map is (64*8) channels x 2 x 2, so the
            # flattened vector has 64*8*2*2 entries.
            nn.Linear(64*8*2*2, self.num_classes),
            nn.ReLU(inplace=True)
            )
        
        self.convs.apply(initilize)
        self.label_embedding.apply(initilize)

    def forward(self, x, label, *args, **kwargs):
        """Return non-negative scores of shape ``(batch, num_classes)``.

        Args:
            x: tensor of shape ``(batch, 5, 32, 32)``.
            label: integer class indices of shape ``(batch,)``.
        """
        label_output = self.label_embedding(label)
        # Reshape the projected embedding into a one-channel 32x32 plane.
        label_output = label_output.view(-1, 1, 32, 32)
        concat = torch.cat((x, label_output), dim=1)
        l = self.convs(concat)
        return l

class Muer_Linear_InputXY(nn.Module):
    """MLP over a flattened input concatenated with a label embedding.

    The input is projected to 64 features, the class index is embedded and
    projected to 16 features, and the resulting 80-dimensional vector is
    passed through an MLP whose hidden widths come from the configuration.

    Args:
        cfg: configuration object providing ``cfg.muer.num_class_embedding``,
            ``cfg.Classifier.input_shape`` and ``cfg.Classifier.hidden_dims``.
        **kwargs: must contain ``num_classes``.
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__()
        self.num_classes = kwargs['num_classes']
        self.num_class_embedding = cfg.muer.num_class_embedding
        # np.prod returns a numpy scalar; store a plain int so that
        # nn.Linear receives a native size.
        self.input_dim = int(np.prod(cfg.Classifier.input_shape))
        self.hidden_dims = cfg.Classifier.hidden_dims

        self.label_embedding = nn.Sequential(
            nn.Embedding(self.num_classes, self.num_class_embedding),
            nn.Linear(self.num_class_embedding, 16))

        self.latent = nn.Sequential(nn.Linear(self.input_dim, 64), nn.ReLU())

        # Layer widths: 64 latent + 16 label features, then the hidden sizes.
        # An empty hidden_dims yields a single output layer.
        widths = [64 + 16, *self.hidden_dims]
        layers = []
        for input_dim, output_dim in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(input_dim, output_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], self.num_classes))
        self.model = nn.Sequential(*layers)
        
        self.model.apply(initilize)
        self.latent.apply(initilize)
        self.label_embedding.apply(initilize)

    def forward(self, x, label, *args, **kwargs):
        """Return scores of shape ``(batch, num_classes)``.

        Args:
            x: tensor of shape ``(batch, input_dim)`` or
                ``(batch, *input_shape)``; trailing dimensions are flattened.
            label: integer class indices of shape ``(batch,)``.
        """
        label_output = self.label_embedding(label)
        # Flattening is a no-op for inputs that are already 2-D.
        latent_output = self.latent(x.flatten(start_dim=1))
        concat = torch.cat((latent_output, label_output), dim=1)
        l = self.model(concat)
        return l


class AlphaModule(nn.Module):
    """Wraps one learnable scalar and returns its exponential.

    The constructor argument ``alpha`` is accepted but not used: the
    parameter always starts from a uniform random draw between 0.01 and 0.1.
    """

    def __init__(self, alpha):
        super().__init__()
        # the -alpha in the paper
        initial_value = torch.Tensor(1)
        initial_value.uniform_(0.01, 0.1)
        self.alpha = nn.Parameter(initial_value)

    def forward(self):
        """Return ``exp(alpha)`` as a one-element tensor."""
        return self.alpha.exp()
        
class LambdaModule:
    """Per-parameter coefficients fitted from parameter-weighted gradients.

    For every sample ``x_i`` and every named parameter ``theta_j`` of ``net``
    the scalar ``sum(d phi(x_i)[0, 0] / d theta_j * theta_j)`` is computed,
    giving a matrix ``Nabla`` of shape ``(num_samples, num_parameters)``.
    ``Lambda`` is the least-squares solution of
    ``Nabla @ Lambda = phi(samples)[:, 0]``.

    Attributes set during construction:
        parameter_names: parameter names of ``net`` in sorted order; this is
            the column order of ``Nabla`` and the entry order of ``Lambda``.
        Lambda: 1-D CPU tensor with one coefficient per named parameter.
        Lambda_tilde: copy of ``Lambda`` in which entries more than ``delta``
            below the maximum are multiplied by zero.

    Args:
        net: module called as ``net(batch)`` that returns a 2-D tensor; it
            must have at least one parameter.
        samples: tensor whose first dimension enumerates the samples.
    """

    def __init__(self, net, samples):
        self.net = net
        self.device = next(self.net.parameters()).device
        self.samples = samples
        self.parameter_names = sorted([name for name, _ in net.named_parameters()])
        self.cal_Lambda()
        self.cal_Lambda_tilde()

    def get_grad_matrix(self, samples):
        """Return the ``(num_samples, num_parameters)`` matrix ``Nabla``.

        Row ``i`` holds, for each name in ``parameter_names``, the sum over
        that parameter of ``grad * value`` after back-propagating the first
        output of ``net`` on sample ``i``. Parameters that receive no
        gradient (frozen or not used by the output) contribute ``0.0``.
        """
        # The parameter objects do not change between samples.
        named_parameters = dict(self.net.named_parameters())
        grads = []
        for sample in samples:
            sample = sample.unsqueeze(0).to(self.device)
            phi = self.net(sample)
            # Clear gradients left over from the previous sample.
            self.net.zero_grad()
            phi[0, 0].backward()
            grad = []
            for name in self.parameter_names:
                parameter = named_parameters[name]
                if parameter.grad is None:
                    grad.append(0.0)
                    continue
                element = (parameter.grad.detach() * parameter.detach()).sum()
                grad.append(element.item())
            grads.append(grad)
        return torch.tensor(grads)

    def cal_Lambda(self):
        """Fit ``Lambda`` by least squares, store it and return ``Lambda_dict``."""
        Nabla = self.get_grad_matrix(self.samples)
        Nabla = Nabla.detach()
        with torch.no_grad():
            samples = self.samples.to(self.device)
            # Keep only the first output column, as an (num_samples, 1) target.
            phis = self.net(samples)[:, 0:1].detach().cpu()
            print("Rank(Nabla)=", torch.linalg.matrix_rank(Nabla))
            print("Shape(Nabla)=", Nabla.shape)
            print("Shape(phis)=", phis.shape)
            # lstsq needs both operands in the same dtype; Nabla is built
            # from Python floats and therefore has the default dtype.
            phis = phis.to(Nabla.dtype)
            Lambda = torch.linalg.lstsq(Nabla, phis).solution
            self.Lambda = Lambda.squeeze(-1)
        return self.Lambda_dict

    @property
    def Lambda_dict(self):
        """Mapping from parameter name to its entry of ``Lambda``."""
        return dict(zip(self.parameter_names, self.Lambda))

    @property
    def Lambda_tilde_dict(self):
        """Mapping from parameter name to its entry of ``Lambda_tilde``."""
        return dict(zip(self.parameter_names, self.Lambda_tilde))

    def save(self):
        """Placeholder; currently does nothing."""
        pass

    def cal_Lambda_tilde(self, delta = 0.01):
        """Compute ``Lambda_tilde``, store it and return ``Lambda_tilde_dict``.

        Args:
            delta: entries smaller than ``max(Lambda) - delta`` are
                multiplied by zero.
        """
        with torch.no_grad():
            max_val = torch.max(self.Lambda)
            Lambda_tilde = self.Lambda.clone()
            Lambda_tilde[Lambda_tilde < (max_val - delta)] *= 0.0
            self.Lambda_tilde = Lambda_tilde
        return self.Lambda_tilde_dict

    def cal_Lambda_prime(self, alpha):
        """Return ``Lambda_tilde * exp(alpha * (2 * Lambda - 1))``.

        The computation runs on the device of ``alpha``.

        Args:
            alpha: tensor scaling the exponent.
        """
        device = alpha.device
        Lambda_tilde = self.Lambda_tilde.to(device)
        Lambda = self.Lambda.to(device)
        Lambda_prime = Lambda_tilde * torch.exp(alpha*(2*Lambda - 1))
        return Lambda_prime

    def Lambda_prime_dict(self, alpha):
        """Mapping from parameter name to its entry of ``cal_Lambda_prime(alpha)``."""
        Lambda_prime = self.cal_Lambda_prime(alpha)
        return dict(zip(self.parameter_names, Lambda_prime))